#!/usr/bin/python3
##
## This file is part of the sigrok-util project.
##
## Copyright (C) 2013 Marcus Comstedt <marcus@mc.pp.se>
##
## This program is free software; you can redistribute it and/or modify
## it under the terms of the GNU General Public License as published by
## the Free Software Foundation; either version 3 of the License, or
## (at your option) any later version.
##
## This program is distributed in the hope that it will be useful,
## but WITHOUT ANY WARRANTY; without even the implied warranty of
## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
## GNU General Public License for more details.
##
## You should have received a copy of the GNU General Public License
## along with this program; if not, see <http://www.gnu.org/licenses/>.
##

import sys
import struct
import parseelf

class searcher:
    """Cursor over the raw bytes of one section.

    The current position is tracked twice: ``offset`` is the index into
    ``data`` and ``address`` is ``baseaddr + offset``.  All movement
    methods keep the two in sync and raise ``Exception`` rather than
    leaving the cursor outside ``0..length``.
    """

    def __init__(self, data, addr):
        """Wrap *data* (bytes), whose first byte lives at address *addr*."""
        self.data = data
        self.baseaddr = addr
        self.length = len(data)
        self.reset()

    def reset(self, offs=0):
        """Move the cursor to absolute offset *offs* (default: the start)."""
        if offs < 0 or offs > self.length:
            raise Exception('Reset past end of section')
        self.address = self.baseaddr + offs
        self.offset = offs

    def skip(self, cnt):
        """Move the cursor by *cnt* bytes; *cnt* may be negative."""
        target = self.offset + cnt
        if target > self.length:
            raise Exception('Skip past end of section')
        if target < 0:
            # A negative offset would silently index from the end of data.
            raise Exception('Skip before start of section')
        self.address += cnt
        self.offset += cnt

    def peek(self, cnt, offs=0):
        """Return *cnt* bytes starting *offs* bytes after the cursor.

        The cursor itself is not moved.
        """
        start = self.offset + offs
        if start < 0 or start + cnt > self.length:
            raise Exception('Peek past end of section')
        return self.data[start:start + cnt]

    def look_for(self, needle):
        """Advance the cursor to the next occurrence of *needle*.

        The search starts at the current position, so the cursor does not
        move if *needle* is already there.
        """
        pos = self.data.find(needle, self.offset)
        if pos < 0:
            raise Exception('Needle not found in haystack')
        self.skip(pos - self.offset)

    def look_for_either(self, needle1, needle2):
        """Advance the cursor to whichever needle occurs first.

        Raises ``Exception`` only when neither needle is found.
        """
        candidates = (self.data.find(needle1, self.offset),
                      self.data.find(needle2, self.offset))
        # find() reports "missing" as -1; drop those before taking the
        # minimum so a missing needle can never win over a real match.
        hits = [pos for pos in candidates if pos >= 0]
        if not hits:
            raise Exception('Needle not found in haystack')
        self.skip(min(hits) - self.offset)

def search_plt_32(plt, addr):
    """Return the address of the first ``jmp *addr32`` found in *plt*.

    *plt* is rewound to its start before searching and is left positioned
    on the matching instruction.  Whatever ``plt.look_for`` raises when
    the pattern is absent propagates to the caller.
    """
    plt.reset()
    # Encoding of the indirect jump: ff 25 followed by the absolute
    # little-endian 32-bit slot address.
    jmp_insn = struct.pack('<BBI', 0xff, 0x25, addr)
    plt.look_for(jmp_insn)
    return plt.address

def search_plt_64(plt, addr):
    """Return the address of the ``jmpq *offs32(%rip)`` that targets *addr*.

    *plt* is rewound first.  Every ``ff 25`` byte pair is treated as a
    candidate; its 32-bit displacement is taken relative to the end of the
    6-byte instruction.  The loop only ends by returning or by an
    exception raised from *plt* once the data is exhausted.
    """
    insn_len = 6        # ff 25 + disp32
    plt.reset()
    while True:
        plt.look_for(b'\xff\x25')
        (disp,) = struct.unpack('<i', plt.peek(4, 2))
        target = plt.address + disp + insn_len
        if target != addr:
            # Not our slot: step over the opcode and keep scanning.
            plt.skip(2)
            continue
        return plt.address

def find_hex_file_lines_constructor_32(text, hex_file_lines_got, got_plt):
    """Advance *text* just past the load of the hex-file-lines GOT slot.

    Scans forward from the current position for
    ``mov offs32(%ebx),%edi`` or ``mov offs32(%ebx),%esi`` whose
    displacement, added to *got_plt*, equals *hex_file_lines_got*.  On
    success the cursor sits right after that 6-byte instruction.  There is
    no return value; failure surfaces as an exception from *text*.
    """
    while True:
        text.look_for_either(b'\x8b\xbb', b'\x8b\xb3')
        (disp,) = struct.unpack('<i', text.peek(4, 2))
        if got_plt + disp != hex_file_lines_got:
            # Some other GOT-relative load: step over the opcode bytes.
            text.skip(2)
            continue
        text.skip(6)
        return

def find_hex_file_lines_constructor_64(text, hex_file_lines_got):
    """Advance *text* just past the load of the hex-file-lines GOT slot.

    Scans forward from the current position for
    ``mov offs32(%rip),%rbp`` whose displacement, taken relative to the
    end of the 7-byte instruction, resolves to *hex_file_lines_got*.  On
    success the cursor sits right after that instruction.  There is no
    return value; failure surfaces as an exception from *text*.
    """
    insn_len = 7        # 48 8b 2d + disp32
    while True:
        text.look_for(b'\x48\x8b\x2d')
        (disp,) = struct.unpack('<i', text.peek(4, 3))
        if text.address + disp + insn_len != hex_file_lines_got:
            # Loads something else: step over the opcode bytes.
            text.skip(3)
            continue
        text.skip(insn_len)
        return

def parse_hex_file_lines_constructor_32(text, basic_string_plt, got_plt, lines):
    """Decode the 32-bit code that fills the hex-file-lines array.

    *text* must be positioned directly after the 6-byte
    ``mov offs32(%ebx),{%edi,%esi}`` located by
    find_hex_file_lines_constructor_32.  For each of the ``len(lines)``
    expected strings the code is matched against one fixed instruction
    sequence: optional setup instructions are skipped when present, the
    string address and the destination slot are decoded, and the call must
    target *basic_string_plt*.

    *lines* is updated in place: ``lines[i]`` receives the address
    (*got_plt* + displacement) of the string stored into slot ``i``.  Every
    slot must still be 0 when it is filled.

    Raises ``Exception`` on any instruction that does not fit the
    pattern, on a slot offset that is negative, misaligned or outside
    *lines*, on a slot filled twice, and on a zero string address.
    """
    slot_size = 4

    # Re-read the ModRM byte of the preceding mov to learn which register
    # holds the array base: 0xb3 selects %esi, otherwise %edi.
    text.skip(-5)
    esi = (text.peek(1) == b'\xb3')
    text.skip(5)

    # Encodings that depend on the base register.  The "other" register
    # of the %esi/%edi pair is the one set by the optional mov/xor.
    if esi:
        mov_imm_other, xor_other = b'\xbf', b'\x31\xff'
        mov_base_to_stack = b'\x89\x34\x24'
        lea_base_offs8, lea_base_offs32 = b'\x8d\x46', b'\x8d\x86'
    else:
        mov_imm_other, xor_other = b'\xbe', b'\x31\xf6'
        mov_base_to_stack = b'\x89\x3c\x24'
        lea_base_offs8, lea_base_offs32 = b'\x8d\x47', b'\x8d\x87'

    for _ in range(len(lines)):
        if text.peek(2) == b'\x8d\x45':         # lea offs8(%ebp),%eax
            text.skip(3)
        elif text.peek(2) == b'\x8d\x85':       # lea offs32(%ebp),%eax
            text.skip(6)
        if text.peek(1) == mov_imm_other:       # mov $imm32,<other reg>
            text.skip(5)
        elif text.peek(2) == xor_other:         # xor <other reg>,<other reg>
            text.skip(2)
        if text.peek(4) == b'\x89\x44\x24\x08': # mov %eax,0x8(%esp)
            text.skip(4)

        if text.peek(2) == b'\x8d\x83':         # lea offs32(%ebx),%eax
            straddr = struct.unpack('<i', text.peek(4, 2))[0]
            text.skip(6)
            straddr += got_plt
        else:
            raise Exception('Expected lea offs32(%ebx),%eax @ ' +
                            ('0x%x' % text.address))

        if text.peek(4) == b'\x89\x44\x24\x04': # mov %eax,0x4(%esp)
            text.skip(4)

        if text.peek(3) == mov_base_to_stack:   # mov <base>,(%esp)
            offs = 0
            text.skip(3)
        elif text.peek(2) == lea_base_offs8:    # lea offs8(<base>),%eax
            offs = struct.unpack('<b', text.peek(1, 2))[0]
            text.skip(3)
        elif text.peek(2) == lea_base_offs32:   # lea offs32(<base>),%eax
            offs = struct.unpack('<i', text.peek(4, 2))[0]
            text.skip(6)
        else:
            raise Exception('Expected lea offs(%e'+('s' if esi else 'd')+'i),%eax @ ' +
                            ('0x%x' % text.address))

        # The last valid slot starts at (len - 1) * slot_size, so an offset
        # equal to len * slot_size is already out of range.
        if (offs < 0 or offs >= len(lines) * slot_size
                or offs % slot_size != 0):
            raise Exception('Invalid offset %d' % offs)
        index = offs // slot_size
        if lines[index] != 0:
            raise Exception('Line %d filled multiple times' % index)

        if text.peek(3) == b'\x89\x04\x24':     # mov %eax,(%esp)
            text.skip(3)
        if text.peek(1) == b'\xe8':             # call offs32
            offs = struct.unpack('<i', text.peek(4, 1))[0]
            text.skip(5)
            # The call displacement is relative to the next instruction.
            if text.address + offs != basic_string_plt:
                raise Exception('Expected call ZNSsC1EPKcRKSaIcE@plt @ ' +
                                ('0x%x' % text.address))
        else:
            raise Exception('Expected call ZNSsC1EPKcRKSaIcE@plt @ ' +
                            ('0x%x' % text.address))

        if straddr == 0:
            raise Exception('NULL pointer stored to index %d' % index)
        lines[index] = straddr

def parse_hex_file_lines_constructor_64(text, basic_string_plt, lines):
    """Decode the 64-bit code that fills the hex-file-lines array.

    *text* must be positioned directly after the
    ``mov offs32(%rip),%rbp`` located by
    find_hex_file_lines_constructor_64.  For each of the ``len(lines)``
    expected strings the code is matched against one fixed instruction
    sequence: optional setup instructions are skipped when present, the
    %rip-relative string address and the destination slot are decoded, and
    the call must target *basic_string_plt*.

    *lines* is updated in place: ``lines[i]`` receives the address of the
    string stored into slot ``i``.  Every slot must still be 0 when it is
    filled.

    Raises ``Exception`` on any instruction that does not fit the
    pattern, on a slot offset that is negative, misaligned or outside
    *lines*, on a slot filled twice, and on a zero string address.
    """
    slot_size = 8

    for _ in range(len(lines)):
        if text.peek(1) == b'\xbb':                 # mov $imm32,%ebx
            text.skip(5)
        elif text.peek(2) == b'\x31\xdb':           # xor %ebx,%ebx
            text.skip(2)
        if text.peek(4) == b'\x48\x8d\x54\x24':     # lea offs8(%rsp),%rdx
            text.skip(5)
        elif text.peek(4) == b'\x48\x8d\x94\x24':   # lea offs32(%rsp),%rdx
            text.skip(8)

        if text.peek(3) == b'\x48\x8d\x35':         # lea offs32(%rip),%rsi
            straddr = struct.unpack('<i', text.peek(4, 3))[0]
            text.skip(7)
            # %rip-relative: the displacement counts from the next
            # instruction, which is where the cursor now stands.
            straddr += text.address
        else:
            raise Exception('Expected lea offs(%rip),%rsi @ ' +
                            ('0x%x' % text.address))

        if text.peek(3) == b'\x48\x89\xef':         # mov %rbp,%rdi
            offs = 0
            text.skip(3)
        elif text.peek(3) == b'\x48\x8d\x7d':       # lea offs8(%rbp),%rdi
            offs = struct.unpack('<b', text.peek(1, 3))[0]
            text.skip(4)
        elif text.peek(3) == b'\x48\x8d\xbd':       # lea offs32(%rbp),%rdi
            offs = struct.unpack('<i', text.peek(4, 3))[0]
            text.skip(7)
        else:
            raise Exception('Expected lea offs(%rbp),%rdi @ ' +
                            ('0x%x' % text.address))

        # The %ebx setup may also appear here instead of at the top.
        if text.peek(1) == b'\xbb':                 # mov $imm32,%ebx
            text.skip(5)
        elif text.peek(2) == b'\x31\xdb':           # xor %ebx,%ebx
            text.skip(2)

        # The last valid slot starts at (len - 1) * slot_size, so an offset
        # equal to len * slot_size is already out of range.
        if (offs < 0 or offs >= len(lines) * slot_size
                or offs % slot_size != 0):
            raise Exception('Invalid offset %d' % offs)
        index = offs // slot_size
        if lines[index] != 0:
            raise Exception('Line %d filled multiple times' % index)

        if text.peek(1) == b'\xe8':                 # callq offs32
            offs = struct.unpack('<i', text.peek(4, 1))[0]
            text.skip(5)
            # The call displacement is relative to the next instruction.
            if text.address + offs != basic_string_plt:
                raise Exception('Expected callq ZNSsC1EPKcRKSaIcE@plt @ ' +
                                ('0x%x' % text.address))
        else:
            raise Exception('Expected callq ZNSsC1EPKcRKSaIcE@plt @ ' +
                            ('0x%x' % text.address))

        if straddr == 0:
            raise Exception('NULL pointer stored to index %d' % index)
        lines[index] = straddr

def find_reloc(elf, symname):
    """Return the first relocation entry that refers to *symname*.

    Each value of ``elf.relocs`` is inspected in turn.  A table is only
    considered when it has a ``'symbols'`` mapping containing *symname*;
    that symbol's ``'number'`` is then matched against the ``'r_sym'``
    field of the table's ``'relocs'`` entries.

    Raises ``Exception`` when no table yields a match.
    """
    for table in elf.relocs.values():
        if 'symbols' not in table or symname not in table['symbols']:
            continue
        wanted = table['symbols'][symname]['number']
        for entry in table['relocs']:
            if entry['r_sym'] == wanted:
                return entry
    raise Exception('Unable to find a relocation against ' + symname)

def ihex_to_binary(lines):
    """Convert Intel HEX text records into one flat binary image.

    *lines* is an iterable of record strings, each starting with ``':'``
    followed by hex digits.  Only data records (type 0, non-empty) and the
    empty end-of-file record (type 1) are accepted.  Data is placed at the
    16-bit address given in each record; gaps between records are filled
    with zero bytes and the image always starts at address 0.

    Returns the image as ``bytes``.

    Raises ``Exception`` for a record that does not start with ``':'``
    (including an empty string), has a bad checksum, is too short or does
    not match its own byte count, has an unsupported type, repeats an
    address, or overlaps another record.
    """
    chunks = {}
    for line in lines:
        # Slicing instead of indexing so an empty line is reported as a
        # malformed record rather than crashing with IndexError.
        if line[:1] != ':':
            raise Exception('ihex line does not start with ":"')
        record = bytes.fromhex(line[1:])
        # All bytes of a record, checksum included, sum to 0 modulo 256.
        if (sum(record) & 0xff) != 0:
            raise Exception('Invalid checksum in ihex')
        # Header (count, address, type) is 4 bytes, checksum is 1.
        if len(record) < 5:
            raise Exception('ihex line too short')
        (byte_count, address, rectype) = struct.unpack('>BHB', record[:4])
        if len(record) != byte_count + 5:
            raise Exception('ihex line length does not match byte count')
        data = record[4:4 + byte_count]
        if rectype == 1 and byte_count == 0:
            pass        # end-of-file record carries no data
        elif rectype != 0 or byte_count == 0:
            raise Exception('Unexpected rectype %d with bytecount %d' %
                            (rectype, byte_count))
        elif address in chunks:
            raise Exception('Multiple ihex lines with address 0x%x' % address)
        else:
            chunks[address] = data

    # bytearray keeps assembly linear; repeated bytes += is quadratic.
    blob = bytearray()
    for address in sorted(chunks):
        if address < len(blob):
            raise Exception('Overlapping ihex chunks')
        if address > len(blob):
            blob.extend(b'\x00' * (address - len(blob)))
        blob.extend(chunks[address])
    return bytes(blob)

def extract_fx2_firmware(elf, symname, filename):
    """Recover FX2 firmware stored as an array of ihex strings and save it.

    The array *symname* is filled at run time, so its contents are
    recovered by decoding the code that fills it:

    1. read the element count from the ``symname + 'Count'`` symbol;
    2. locate the PLT entry for ``_ZNSsC1EPKcRKSaIcE`` through its
       relocation;
    3. scan ``.text`` for a load of the *symname* GOT slot and decode the
       instruction sequence that follows, which yields one string address
       per element;
    4. read each string from ``.rodata``, convert the ihex lines to
       binary and write the result to *filename*.

    A candidate in step 3 that does not decode is abandoned and the scan
    resumes behind it.  If no candidate decodes, the error from the first
    failed candidate is raised; if there was no candidate at all,
    'Unable to find constructor for ...' is raised.

    Raises ``Exception`` for the failures above, for a string address
    outside ``.rodata``, and for whatever ihex_to_binary rejects.
    """
    blob = elf.load_symbol(elf.dynsym[symname + 'Count'])
    count = struct.unpack('<I', blob)[0]
    got_plt = elf.find_section('.got.plt')['sh_addr']
    hex_file_lines_got = find_reloc(elf, symname)['r_offset']
    basic_string_got = find_reloc(elf, '_ZNSsC1EPKcRKSaIcE')['r_offset']
    is_64bit = (elf.elf_wordsize == 64)

    pltsec = elf.find_section('.plt')
    plt = searcher(elf.read_section(pltsec), pltsec['sh_addr'])
    try:
        if is_64bit:
            basic_string_plt = search_plt_64(plt, basic_string_got)
        else:
            basic_string_plt = search_plt_32(plt, basic_string_got)
    except Exception as err:
        # "except Exception" rather than a bare except, so Ctrl-C and
        # SystemExit are not turned into a misleading lookup failure.
        raise Exception('Unable to find a PLT entry for _ZNSsC1EPKcRKSaIcE') from err

    textsec = elf.find_section('.text')
    text = searcher(elf.read_section(textsec), textsec['sh_addr'])
    first_parse_error = None
    while True:
        try:
            if is_64bit:
                find_hex_file_lines_constructor_64(text, hex_file_lines_got)
            else:
                find_hex_file_lines_constructor_32(text, hex_file_lines_got,
                                                   got_plt)
        except Exception as err:
            if first_parse_error is not None:
                # Candidates existed but none decoded: the first decoding
                # error is the most useful diagnostic.
                raise first_parse_error
            raise Exception('Unable to find constructor for ' + symname) from err
        oldoffs = text.offset
        l = [0]*count
        try:
            if is_64bit:
                parse_hex_file_lines_constructor_64(text, basic_string_plt, l)
            else:
                parse_hex_file_lines_constructor_32(text, basic_string_plt,
                                                    got_plt, l)
            break
        except Exception as err:
            # The parsers signal a mismatch with plain Exception (never
            # KeyError), so that is what must be caught for the retry to
            # happen.  Rewind to just behind this candidate and look for
            # the next one.
            if first_parse_error is None:
                first_parse_error = err
            text.reset(oldoffs)

    rodatasec = elf.find_section('.rodata')
    rodata = elf.read_section(rodatasec)
    lo = rodatasec['sh_addr']
    hi = lo + rodatasec['sh_size']
    # Replace each string address with the string read from .rodata.
    for i, addr in enumerate(l):
        if addr < lo or addr >= hi:
            raise Exception('Address 0x%x outside of .rodata section' % addr)
        l[i] = elf.get_name(addr - lo, rodata)

    blob = ihex_to_binary(l)
    with open(filename, 'wb') as f:
        f.write(blob)
    print("saved %d bytes to %s" % (len(blob), filename))

def extract_fx2_firmware_single(elf, symname, filename):
    """Extract FX2 firmware stored as one ';'-separated ihex string.

    Returns ``False`` without touching *filename* when *symname* is not in
    ``elf.dynsym``.  Otherwise the symbol's contents are decoded as text,
    one trailing NUL is dropped, the text is split on ``';'``, converted
    with ihex_to_binary and written to *filename*; ``True`` is returned.
    """
    if symname not in elf.dynsym:
        return False
    text = bytes.decode(elf.load_symbol(elf.dynsym[symname]))
    # endswith() is safe on an empty string, unlike indexing [-1].
    if text.endswith('\0'):
        text = text[:-1]
    blob = ihex_to_binary(text.split(';'))
    # The context manager closes the file even if the write fails.
    with open(filename, 'wb') as f:
        f.write(blob)
    print("saved %d bytes to %s" % (len(blob), filename))
    return True

def extract_fx3_firmware(elf, symname, filename):
    """Concatenate the FX3 firmware symbols ``symname_0``, ``symname_1``, ...

    Symbols are looked up in ``elf.dynsym`` with increasing index until
    one is missing.  Each symbol holds ';'-separated records that start
    with ``':'`` followed by hex digits; one trailing NUL and empty
    segments are ignored.  The hex digits of every record are converted to
    bytes as they are and the results are written to *filename* in order.

    Nothing is written or printed when no record is found.  Raises
    ``Exception`` for a record that does not start with ``':'``.
    """
    index = 0
    blobs = []
    while True:
        sym = elf.dynsym.get(symname + '_' + str(index))
        if not sym:
            break
        index += 1
        text = bytes.decode(elf.load_symbol(sym))
        # endswith() is safe on an empty symbol, unlike indexing [-1].
        if text.endswith('\0'):
            text = text[:-1]
        for part in text.split(';'):
            if not part:
                continue
            if part[0] != ':':
                raise Exception('ihex line does not start with ":"')
            blobs.append(bytes.fromhex(part[1:]))
    if not blobs:
        return
    with open(filename, 'wb') as f:
        f.writelines(blobs)
    print("saved %d bytes from %d blobs to %s" % (sum(map(len, blobs)),
        len(blobs), filename))

def extract_symbol(elf, symname, filename):
    """Write the raw contents of the dynamic symbol *symname* to *filename*.

    Raises ``KeyError`` when *symname* is not in ``elf.dynsym``.
    """
    blob = elf.load_symbol(elf.dynsym[symname])
    # The context manager closes the file even if the write fails.
    with open(filename, 'wb') as f:
        f.write(blob)
    print("saved %d bytes to %s" % (len(blob), filename))

def try_extract_symbol(elf, symname, filename):
    """Save the symbol *symname* to *filename* if the binary has it.

    Does nothing when *symname* is absent from ``elf.dynsym``.
    """
    if symname in elf.dynsym:
        extract_symbol(elf, symname, filename)

def extract_bitstream(elf, lv):
    """Save the Logic16 FPGA bitstream for the voltage variant *lv*.

    *lv* is the string embedded in both the symbol name and the output
    file name (the script calls this with '18' and '33').
    """
    symbol = ''.join(('gLogic16Lv', lv, 'CompressedBitstream'))
    target = ''.join(('saleae-logic16-fpga-', lv, '.bitstream'))
    extract_symbol(elf, symbol, target)

def usage():
    """Print the command-line synopsis and terminate the program."""
    print("sigrok-fwextract-saleae-logic16 <programfile>")
    # Same effect as sys.exit() with no argument.
    raise SystemExit


#
# main
#

if len(sys.argv) != 2:
    usage()

try:
    filename = sys.argv[1]
    elf = parseelf.elf(filename)
    # Only x86 binaries are handled: e_machine 3 is EM_386 and 62 is
    # EM_X86_64 in the ELF specification.
    if elf.ehdr['e_machine'] != 3 and elf.ehdr['e_machine'] != 62:
        raise Exception('Unsupported e_machine')
    # Prefer the single-string form of the FX2 firmware; fall back to
    # decoding the code that fills the array when that symbol is absent.
    if not extract_fx2_firmware_single(elf, 'Logic16FirmwareStrings', 'saleae-logic16-fx2.fw'):
        extract_fx2_firmware(elf, 'gLogic16HexFileLines', 'saleae-logic16-fx2.fw')
    extract_bitstream(elf, '18')
    extract_bitstream(elf, '33')
    extract_fx3_firmware(elf, 'LogicPro16FirmwareStrings', 'saleae-logicpro16-fx3.fw')
    extract_fx3_firmware(elf, 'LogicPro8FirmwareStrings', 'saleae-logicpro8-fx3.fw')
    try_extract_symbol(elf, 'gLogicPro16CompressedBitstream', 'saleae-logicpro16-fpga.bitstream')
    try_extract_symbol(elf, 'gLogicProCompressedBitstream', 'saleae-logicpro8-fpga.bitstream')
except Exception as e:
    print("Error: %s" % str(e))
    # Report the failure to the caller instead of exiting with status 0.
    sys.exit(1)
